/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ambari.logsearch.web.filters;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import io.jsonwebtoken.SignatureException;
import org.apache.ambari.logsearch.conf.AuthPropsConfig;
import org.apache.ambari.logsearch.web.model.JWTAuthenticationToken;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey;

/**
 * Authentication filter that signs a user in from a JSON Web Token carried in a cookie.
 *
 * <p>The {@link RequestMatcher} passed to the constructor is negated, so the filter attempts
 * authentication on every request that does <em>not</em> match it. The filter is bypassed
 * entirely when JWT authentication is disabled in {@link AuthPropsConfig} or when the security
 * context already holds an authenticated, non-anonymous {@link Authentication}.
 *
 * <p>After the superclass success/failure handling has run, the filter issues a redirect: back to
 * the requested URL on success, or to the configured authentication provider URL on failure.
 */
public class LogsearchJWTFilter extends AbstractAuthenticationProcessingFilter {

  private static final Logger LOG = LoggerFactory.getLogger(LogsearchJWTFilter.class);

  // The configured public key is expected to be the bare base64 certificate body; these are
  // wrapped around it before it is handed to the certificate factory.
  private static final String PEM_HEADER = "-----BEGIN CERTIFICATE-----\n";
  private static final String PEM_FOOTER = "\n-----END CERTIFICATE-----";

  private final AuthPropsConfig authPropsConfig;

  /**
   * Creates the filter.
   *
   * @param requestMatcher matcher for the requests that must be <em>excluded</em> from JWT
   *                       authentication (it is wrapped in a {@link NegatedRequestMatcher})
   * @param authPropsConfig source of the JWT related settings (provider URL, public key, cookie
   *                        name, audiences, original URL query parameter name)
   */
  public LogsearchJWTFilter(RequestMatcher requestMatcher, AuthPropsConfig authPropsConfig) {
    super(new NegatedRequestMatcher(requestMatcher));
    this.authPropsConfig = authPropsConfig;
  }

  /**
   * Validates the JWT found in the configured cookie and builds an authenticated token from it.
   *
   * <p>The token signature is verified with the RSA public key of the configured certificate.
   * If the token carries an audience claim, it must be contained in the configured audiences.
   * On success the resulting authentication is also stored in the {@link SecurityContextHolder}.
   *
   * @param request the current request, read for its cookies and URL
   * @param response the current response (not used here)
   * @return an authenticated {@link JWTAuthenticationToken} for the token subject
   * @throws BadCredentialsException if the provider URL or public key is not configured, or if
   *         the token is expired, malformed, badly signed, or fails audience validation
   * @throws ServletException if the configured public key cannot be turned into an RSA key
   */
  @Override
  public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException, ServletException {
    if (StringUtils.isEmpty(authPropsConfig.getProvidedUrl())) {
      throw new BadCredentialsException("Authentication provider URL must not be null or empty.");
    }
    if (StringUtils.isEmpty(authPropsConfig.getPublicKey())) {
      throw new BadCredentialsException("Public key for signature validation must be provisioned.");
    }

    try {
      // NOTE(review): when the cookie is absent a null token is passed to the parser; presumably
      // the parser rejects it with IllegalArgumentException, which is handled below - verify.
      Claims claims = Jwts
        .parser()
        .setSigningKey(parseRSAPublicKey(authPropsConfig.getPublicKey()))
        .parseClaimsJws(getJWTFromCookie(request))
        .getBody();

      String userName = claims.getSubject();
      LOG.info("USERNAME: {}", userName);
      LOG.info("URL = {}", request.getRequestURL());

      // A token without an audience claim is accepted; one with an audience must be known.
      String audience = claims.getAudience();
      if (StringUtils.isNotEmpty(audience) && !authPropsConfig.getAudiences().contains(audience)) {
        throw new IllegalArgumentException(String.format("Audience validation failed. (Not found: %s)", audience));
      }

      Authentication authentication = new JWTAuthenticationToken(userName, authPropsConfig.getPublicKey());
      authentication.setAuthenticated(true);
      SecurityContextHolder.getContext().setAuthentication(authentication);
      return authentication;
    } catch (ExpiredJwtException | MalformedJwtException | SignatureException | IllegalArgumentException e) {
      // NOTE(review): other JWT parsing failures are not listed here and would propagate as-is
      // instead of being converted to an AuthenticationException - confirm this is intended.
      LOG.info("URL = {}", request.getRequestURL());
      // The placeholder is required; without it the failure reason is silently dropped.
      LOG.warn("Error during JWT authentication: {}", e.getMessage());
      throw new BadCredentialsException(e.getMessage(), e);
    }
  }

  /**
   * Skips JWT processing when it is disabled or when the caller is already authenticated;
   * otherwise delegates to the standard authentication processing of the superclass.
   */
  @Override
  public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
    Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
    if (!authPropsConfig.isAuthJwtEnabled() || isAuthenticated(authentication)) {
      chain.doFilter(req, res);
      return;
    }
    super.doFilter(req, res, chain);
  }

  /**
   * Runs the superclass success handling and then redirects to the originally requested URL,
   * including its query string.
   */
  @Override
  protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication authResult) throws IOException, ServletException {
    super.successfulAuthentication(request, response, chain, authResult);
    response.sendRedirect(request.getRequestURL().toString() + getOriginalQueryString(request));
  }

  /**
   * Runs the superclass failure handling and then redirects to the authentication provider's
   * login URL.
   */
  @Override
  protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed) throws IOException, ServletException {
    super.unsuccessfulAuthentication(request, response, failed);
    String loginUrl = constructLoginURL(request);
    response.sendRedirect(loginUrl);
  }

  /**
   * Looks up the serialized JWT in the request cookies.
   *
   * @param req the current request
   * @return the value of the first cookie whose name equals the configured cookie name, or
   *         {@code null} if the request has no such cookie
   */
  private String getJWTFromCookie(HttpServletRequest req) {
    Cookie[] cookies = req.getCookies();
    if (cookies == null) {
      return null;
    }
    String cookieName = authPropsConfig.getCookieName();
    for (Cookie cookie : cookies) {
      if (cookieName.equals(cookie.getName())) {
        LOG.info("{} cookie has been found and is being processed", cookieName);
        return cookie.getValue();
      }
    }
    return null;
  }

  /**
   * Extracts the RSA public key from a base64 encoded X.509 certificate body.
   *
   * @param pem the certificate body without PEM header and footer
   * @return the RSA public key of the certificate
   * @throws ServletException if the certificate cannot be parsed or does not hold an RSA key
   */
  private RSAPublicKey parseRSAPublicKey(String pem) throws ServletException {
    String fullPem = PEM_HEADER + pem + PEM_FOOTER;
    try {
      CertificateFactory factory = CertificateFactory.getInstance("X.509");
      ByteArrayInputStream pemStream = new ByteArrayInputStream(fullPem.getBytes("UTF8"));

      X509Certificate certificate = (X509Certificate) factory.generateCertificate(pemStream);
      // Fail with a descriptive, declared exception instead of a ClassCastException.
      if (!(certificate.getPublicKey() instanceof RSAPublicKey)) {
        throw new ServletException("Configured certificate does not contain an RSA public key.");
      }
      return (RSAPublicKey) certificate.getPublicKey();
    } catch (CertificateException ce) {
      String message;
      if (pem.startsWith(PEM_HEADER)) {
        message = "CertificateException - be sure not to include PEM header "
          + "and footer in the PEM configuration element.";
      } else {
        message = "CertificateException - PEM may be corrupt";
      }
      throw new ServletException(message, ce);
    } catch (UnsupportedEncodingException uee) {
      throw new ServletException(uee);
    }
  }

  /**
   * Builds the provider login URL, passing the current request URL (with its query string) in
   * the configured "original URL" query parameter.
   *
   * @param request the current request
   * @return the URL to redirect an unauthenticated caller to
   */
  private String constructLoginURL(HttpServletRequest request) {
    String providerUrl = authPropsConfig.getProvidedUrl();
    // Append to an existing query string if the provider URL already has one.
    String delimiter = providerUrl.contains("?") ? "&" : "?";
    // NOTE(review): the original URL is appended without URL-encoding, so a query string
    // containing '&' is split into separate parameters of the provider URL - confirm the
    // provider copes with this before changing it.
    return providerUrl + delimiter
      + authPropsConfig.getOriginalUrlQueryParam() + "="
      + request.getRequestURL().toString() + getOriginalQueryString(request);
  }

  /**
   * @return the request's query string prefixed with {@code ?}, or an empty string if the
   *         request has none
   */
  private String getOriginalQueryString(HttpServletRequest request) {
    String originalQueryString = request.getQueryString();
    return (originalQueryString == null) ? "" : "?" + originalQueryString;
  }

  /**
   * @return {@code true} if the given authentication is present, is not an anonymous token and
   *         reports itself as authenticated
   */
  private boolean isAuthenticated(Authentication authentication) {
    return authentication != null && !(authentication instanceof AnonymousAuthenticationToken) && authentication.isAuthenticated();
  }

}
